Disjoint Set
A disjoint set is a data structure that stores a collection of distinct, non-overlapping sets. It supports the following operations:
Union: Merge two distinct sets into a single set
Find: Find the set to which an element belongs
Because of these operations, it is also called a union-find data structure. A disjoint set usually also provides a makeSet (or createSet) function to initialize the set and/or introduce new elements.
Disjoint sets are used by many applications and algorithms, and they play a special role in Kruskal's minimum spanning tree algorithm. Here we look at different implementations and optimizations of disjoint sets. For a detailed theoretical explanation, see Wikipedia or a textbook such as Introduction to Algorithms.
Before we begin implementation, some important things to note:
The sets have to be disjoint. Example: if A = {1,2,3} and B = {3,4}, then A and B are not disjoint, as both contain the element 3. If A = {1,2} and B = {3,4}, then A and B are disjoint, as they have no common element.
Disjoint sets can be implemented using arrays, hashmaps, or linked lists (there may be more options; suggestions are welcome).
Both the array and the hashmap implementations store a tree representation of the disjoint sets, which is also called a disjoint-set forest.
In this chapter, we will see different implementations of the disjoint set forest.
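The disjointness condition from the notes above can be checked directly with Python's built-in set type. This is only an illustration of the definition, not part of the disjoint set implementations that follow; the variable names are chosen for this example.

```python
# The two pairs of sets from the example, written as built-in Python sets.
a = {1, 2, 3}
b = {3, 4}
c = {1, 2}
d = {3, 4}

# set.isdisjoint returns True when the two sets share no element.
a_b_disjoint = a.isdisjoint(b)   # both contain 3
c_d_disjoint = c.isdisjoint(d)   # no common element

print(a_b_disjoint)  # False
print(c_d_disjoint)  # True
print(a & b)         # {3}, the element that breaks disjointness
```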
Disjoint Set: Using arrays
In this implementation we use an array (the list type in Python) to store the parent pointers. We are just going to create a working implementation without any optimizations.
We assume the following:
The maximum size N is fixed. The number of elements is defined when the set is created (inside the constructor) and cannot be changed later.
The elements of the set are integers in the range [0, N), i.e. 0 included and N excluded, where N denotes the size of the set.
class DisjointSetArray:
    def __init__(self, size):
        """
        Create a disjoint set object with each element in its own subset,
        i.e. parent[i] = i
        """
        self.__parents = list(range(size))
        self.__size = size

    def find(self, v):
        """
        Returns the topmost parent of 'v', i.e. the root of the set 'v' is a part of.
        Follows the parent pointers recursively until parent[v] == v.
        """
        self.__is_element_valid(v)
        if self.__parents[v] == v:
            return v
        return self.find(self.__parents[v])

    def union(self, v1, v2):
        """
        Makes the root of v2's set the parent of the root of v1's set
        """
        self.__is_element_valid(v1)
        self.__is_element_valid(v2)
        p1 = self.find(v1)
        p2 = self.find(v2)
        # already in same set
        if p1 == p2:
            return
        self.__parents[p1] = p2

    def __is_element_valid(self, v):
        if v >= self.__size or v < 0:
            raise ValueError(f"Invalid set element: {v}")
In the above implementation, the find() operation has a worst-case time complexity of O(n) when the tree is skewed. Its space complexity is also O(n) because of the recursive calls. We can easily replace it with an iterative version, which brings the extra space down to O(1). See the iterative implementation of find below.
    # replaces the recursive find method of DisjointSetArray
    def find(self, v):
        """
        Returns the topmost parent of 'v', i.e. the root of the set 'v' is a part of.
        Follows the parent pointers iteratively until parent[v] == v.
        """
        self.__is_element_valid(v)
        while self.__parents[v] != v:
            v = self.__parents[v]
        return v
Below are some unit test cases for the simple implementation of disjoint set we did above.
import unittest

class DisjointSetV1Test(unittest.TestCase):
    def test_1(self):
        d = DisjointSetArray(10)
        # Initially each element's parent is itself
        self.assertEqual(d.find(1), 1)
        self.assertEqual(d.find(2), 2)
        # -1 is not present in set and hence raises ValueError
        with self.assertRaises(ValueError):
            d.find(-1)
        # create set {1,2}
        d.union(1, 2)
        self.assertEqual(d.find(1), d.find(2))
        # create set {3,4}
        d.union(3, 4)
        self.assertEqual(d.find(3), d.find(4))
        # create set {1,2,3,4}
        d.union(1, 4)
        # after union of 1 and 4, 2 and 3 should also be in the same set
        self.assertEqual(d.find(2), d.find(3))

_ = unittest.TextTestRunner(verbosity=2).run(unittest.defaultTestLoader.loadTestsFromTestCase(DisjointSetV1Test))
test_1 (__main__.DisjointSetV1Test) ...
ok
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
The union operation still has a worst-case time complexity of O(n), since it first calls find on both nodes to locate their roots. The actual linking step, self.__parents[p1] = p2, takes O(1) time because we do not update the children of p1.
If union is called directly on root nodes, the whole operation is O(1).
This implementation is known as Quick-Union.
However, it may produce completely skewed trees, which lead to slow find operations.
For example, the following sequence of union operations generates a skewed tree.
d = DisjointSetArray(5)
d.union(0, 1) # 0 -> 1
d.union(1, 2) # 0 -> 1 -> 2
d.union(2, 3) # 0 -> 1 -> 2 -> 3
d.union(3, 4) # 0 -> 1 -> 2 -> 3 -> 4
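To see why such a chain makes find slow, here is a minimal sketch that works on a bare parent list instead of the class above. It assumes the unions union(0, 1), union(1, 2), union(2, 3), union(3, 4) were applied in that order; hops_to_root is a hypothetical helper written only for this illustration.

```python
def hops_to_root(parents, v):
    """Count the parent links followed from v until a root (parents[r] == r)."""
    hops = 0
    while parents[v] != v:
        v = parents[v]
        hops += 1
    return hops

n = 5
parents = list(range(n))
# Same effect as union(0, 1), union(1, 2), union(2, 3), union(3, 4) in the
# naive implementation: each old root is attached under the next element.
for v in range(n - 1):
    parents[v] = v + 1

print(parents)                   # [1, 2, 3, 4, 4]
print(hops_to_root(parents, 0))  # 4, i.e. n - 1 links for the deepest node
```

In general a chain of n elements costs n - 1 hops for the deepest element, which is where the O(n) worst case of find comes from.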
Disjoint Set with Path Compression
Here the union() operation stays the same, but we alter find() to update the parent pointers of all nodes it traverses on its way to the root of the set, so that they point directly at the root. This performs better for repeated find operations: m operations on a set of size n take O((m+n) log n) time in total.
The amortized time complexity of a find operation therefore turns out to be O(log n).
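The effect of path compression is easiest to see on a plain parent list. The sketch below is a standalone function (find_compress is a name made up for this example, not part of the classes in this chapter) applied to a five-element chain.

```python
def find_compress(parents, v):
    """Return the root of v and point every node on the path directly at it."""
    root = v
    while parents[root] != root:
        root = parents[root]
    # path compression step: re-link each visited node to the root
    while v != root:
        parents[v], v = root, parents[v]
    return root

# Skewed chain 0 -> 1 -> 2 -> 3 -> 4, with 4 as the root.
parents = [1, 2, 3, 4, 4]
root = find_compress(parents, 0)

print(root)     # 4
print(parents)  # [4, 4, 4, 4, 4]
```

After a single find on the deepest node, every node on the path is one hop away from the root, so later finds on those nodes are cheap.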
class DisjointSetPathCompression:
    def __init__(self, size):
        """
        Create a disjoint set object with each element in its own subset,
        i.e. parent[i] = i
        """
        self.__parents = list(range(size))
        self.__size = size

    def find(self, v):
        """
        This implementation of find uses simple path compression.
        """
        self.__is_element_valid(v)
        root = v
        while self.__parents[root] != root:
            root = self.__parents[root]
        # path compression step
        while v != root:
            self.__parents[v], v = root, self.__parents[v]
        return root

    def union(self, v1, v2):
        """
        Makes the root of v2's set the parent of the root of v1's set
        """
        self.__is_element_valid(v1)
        self.__is_element_valid(v2)
        p1 = self.find(v1)
        p2 = self.find(v2)
        # already in same set
        if p1 == p2:
            return
        self.__parents[p1] = p2

    def __is_element_valid(self, v):
        if v >= self.__size or v < 0:
            raise ValueError(f"Invalid set element: {v}")
class DisjointSetPathCompressionTest(unittest.TestCase):
    def test_1(self):
        d = DisjointSetPathCompression(10)
        # Initially each element's parent is itself
        self.assertEqual(d.find(1), 1)
        self.assertEqual(d.find(2), 2)
        # -1 is not present in set and hence raises ValueError
        with self.assertRaises(ValueError):
            d.find(-1)
        # create set {1,2}
        d.union(1, 2)
        self.assertEqual(d.find(1), d.find(2))
        # create set {3,4}
        d.union(3, 4)
        self.assertEqual(d.find(3), d.find(4))
        # create set {1,2,3,4}
        d.union(1, 4)
        # after union of 1 and 4, 2 and 3 should also be in the same set
        self.assertEqual(d.find(2), d.find(3))

_ = unittest.TextTestRunner(verbosity=2).run(unittest.defaultTestLoader.loadTestsFromTestCase(DisjointSetPathCompressionTest))
test_1 (__main__.DisjointSetPathCompressionTest) ...
ok
----------------------------------------------------------------------
Ran 1 test in 0.002s
OK
Disjoint Set: Using union by size/weight
Here the find() operation stays the same, and we use the version without path compression. The union() operation now compares the sizes (number of nodes) of the two trees and makes the smaller tree a subtree of the larger one. If the sizes are equal, either choice works.
With this approach the height of any tree is at most log2(n), so find operations run in O(log n) time.
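The logarithmic bound can be checked on a small input. The sketch below uses standalone helper functions on plain lists (find, union_by_size and depth are names chosen for this illustration) and merges equal-sized sets repeatedly, which is the case that makes trees as tall as union by size allows.

```python
import math

def find(parents, v):
    while parents[v] != v:
        v = parents[v]
    return v

def union_by_size(parents, sizes, a, b):
    ra, rb = find(parents, a), find(parents, b)
    if ra == rb:
        return
    # attach the smaller tree under the larger one
    if sizes[ra] > sizes[rb]:
        ra, rb = rb, ra
    parents[ra] = rb
    sizes[rb] += sizes[ra]

def depth(parents, v):
    """Number of parent links between v and its root."""
    hops = 0
    while parents[v] != v:
        v = parents[v]
        hops += 1
    return hops

n = 16
parents = list(range(n))
sizes = [1] * n
# Merge sets of size 1, then 2, then 4, then 8.
step = 1
while step < n:
    for i in range(0, n, 2 * step):
        union_by_size(parents, sizes, i, i + step)
    step *= 2

max_depth = max(depth(parents, v) for v in range(n))
print(max_depth, math.log2(n))
```

Even in this unfavourable merge order the deepest node stays within log2(n) hops of the root, whereas the naive union could produce a chain of n - 1 hops.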
class DisjointSetWeighed:
    def __init__(self, size):
        """
        Create a disjoint set object with each element in its own subset,
        i.e. parent[i] = i
        Initialize a size/weight array with initial weight of 1 for each element
        """
        self.__parents = list(range(size))
        self.__weights = [1] * size
        self.__size = size

    def find(self, v):
        """
        Returns the topmost parent of 'v', i.e. the root of the set 'v' is a part of.
        Follows the parent pointers iteratively until parent[v] == v.
        """
        self.__is_element_valid(v)
        while self.__parents[v] != v:
            v = self.__parents[v]
        return v

    def union(self, v1, v2):
        """
        Attaches the root of the smaller set under the root of the larger set
        """
        self.__is_element_valid(v1)
        self.__is_element_valid(v2)
        p1 = self.find(v1)
        p2 = self.find(v2)
        # already in same set
        if p1 == p2:
            return
        # swap p1 and p2 if size of p1 more than p2, as we will make p2 parent of p1
        if self.__weights[p1] > self.__weights[p2]:
            p1, p2 = p2, p1
        # make p2 parent of p1 and update size of p2
        self.__parents[p1] = p2
        self.__weights[p2] += self.__weights[p1]

    def __is_element_valid(self, v):
        if v >= self.__size or v < 0:
            raise ValueError(f"Invalid set element: {v}")
class DisjointSetWeighedTest(unittest.TestCase):
    def test_1(self):
        d = DisjointSetWeighed(10)
        # Initially each element's parent is itself
        self.assertEqual(d.find(1), 1)
        self.assertEqual(d.find(2), 2)
        # -1 is not present in set and hence raises ValueError
        with self.assertRaises(ValueError):
            d.find(-1)
        # create set {1,2}
        d.union(1, 2)
        self.assertEqual(d.find(1), d.find(2))
        # create set {3,4}
        d.union(3, 4)
        self.assertEqual(d.find(3), d.find(4))
        # create set {1,2,3,4}
        d.union(1, 4)
        # after union of 1 and 4, 2 and 3 should also be in the same set
        self.assertEqual(d.find(2), d.find(3))

_ = unittest.TextTestRunner(verbosity=2).run(unittest.defaultTestLoader.loadTestsFromTestCase(DisjointSetWeighedTest))
test_1 (__main__.DisjointSetWeighedTest) ...
ok
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
Disjoint Set using union by rank/height
This is similar to union by size. Instead of maintaining the number of nodes in each tree, we maintain its height. During a union, we make the tree with the smaller height a subtree of the one with the larger height. If both trees have the same height, we increase the height of the tree that becomes the parent by 1.
Again, only the union method changes.
class DisjointSetRanked:
    def __init__(self, size):
        """
        Create a disjoint set object with each element in its own subset,
        i.e. parent[i] = i
        Initialize a height array with initial height of 1 for each element
        """
        self.__parents = list(range(size))
        self.__heights = [1] * size
        self.__size = size

    def find(self, v):
        """
        Returns the topmost parent of 'v', i.e. the root of the set 'v' is a part of.
        Follows the parent pointers iteratively until parent[v] == v.
        """
        self.__is_element_valid(v)
        while self.__parents[v] != v:
            v = self.__parents[v]
        return v

    def union(self, v1, v2):
        """
        Merges the sets for v1 and v2 based on height/rank
        """
        self.__is_element_valid(v1)
        self.__is_element_valid(v2)
        p1 = self.find(v1)
        p2 = self.find(v2)
        # already in same set
        if p1 == p2:
            return
        if self.__heights[p1] > self.__heights[p2]:
            # p1's tree is taller: attach p2 under it, height unchanged
            self.__parents[p2] = p1
        elif self.__heights[p2] > self.__heights[p1]:
            # p2's tree is taller: attach p1 under it, height unchanged
            self.__parents[p1] = p2
        else:
            # equal heights: pick p2 as the parent and grow its height by 1
            self.__parents[p1] = p2
            self.__heights[p2] += 1

    def __is_element_valid(self, v):
        if v >= self.__size or v < 0:
            raise ValueError(f"Invalid set element: {v}")
class DisjointSetRankedTest(unittest.TestCase):
    def test_1(self):
        d = DisjointSetRanked(10)
        # Initially each element's parent is itself
        self.assertEqual(d.find(1), 1)
        self.assertEqual(d.find(2), 2)
        # -1 is not present in set and hence raises ValueError
        with self.assertRaises(ValueError):
            d.find(-1)
        # create set {1,2}
        d.union(1, 2)
        self.assertEqual(d.find(1), d.find(2))
        # create set {3,4}
        d.union(3, 4)
        self.assertEqual(d.find(3), d.find(4))
        # create set {1,2,3,4}
        d.union(1, 4)
        # after union of 1 and 4, 2 and 3 should also be in the same set
        self.assertEqual(d.find(2), d.find(3))

_ = unittest.TextTestRunner(verbosity=2).run(unittest.defaultTestLoader.loadTestsFromTestCase(DisjointSetRankedTest))
test_1 (__main__.DisjointSetRankedTest) ...
ok
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
Disjoint Set with union by size + path compression
Using union by size together with path compression gives the best amortized time complexity. We are not going to derive the time complexity here, as the idea is to focus on the implementation.
In theory, the amortized time complexity per operation is O(α(n)), where α(n) is the inverse Ackermann function, which grows extremely slowly: α(n) is at most 4 for any practical value of n.
Hence the amortized time complexity of union by size + path compression is nearly constant.
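A rough way to get a feel for the difference is to count how many parent links the find calls follow in total. The sketch below is an informal experiment rather than a proof or a benchmark: run, use_size and use_compression are hypothetical names for this illustration, and the workload (always uniting element 0 with the next element) is chosen because it is bad for the naive version.

```python
def run(n, use_size, use_compression):
    """Perform n - 1 unions and n finds; return the total number of hops."""
    parents = list(range(n))
    sizes = [1] * n
    hops = 0  # parent links followed across all find calls

    def find(v):
        nonlocal hops
        root = v
        while parents[root] != root:
            root = parents[root]
            hops += 1
        if use_compression:
            while v != root:
                parents[v], v = root, parents[v]
        return root

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return
        if use_size and sizes[ra] > sizes[rb]:
            ra, rb = rb, ra
        parents[ra] = rb
        sizes[rb] += sizes[ra]

    for v in range(n - 1):
        union(0, v + 1)
    for v in range(n):
        find(v)
    return hops

n = 200
naive_hops = run(n, use_size=False, use_compression=False)
fast_hops = run(n, use_size=True, use_compression=True)
print(naive_hops, fast_hops)
```

On this workload the naive version builds one long chain, so its hop count grows roughly quadratically with n, while the optimized version stays within a small constant number of hops per operation.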
class DisjointSetWeightedCompressed:
    def __init__(self, size):
        """
        Create a disjoint set object with each element in its own subset,
        i.e. parent[i] = i
        Initialize a size/weight array with initial weight of 1 for each element
        """
        self.__parents = list(range(size))
        self.__weights = [1] * size
        self.__size = size

    def find(self, v):
        """
        Returns the topmost parent of 'v', i.e. the root of the set 'v' is a part of.
        Follows the parent pointers iteratively until parent[v] == v,
        then does path compression.
        """
        self.__is_element_valid(v)
        root = v
        while self.__parents[root] != root:
            root = self.__parents[root]
        # path compression step
        while v != root:
            self.__parents[v], v = root, self.__parents[v]
        return root

    def union(self, v1, v2):
        """
        Attaches the root of the smaller set under the root of the larger set
        """
        self.__is_element_valid(v1)
        self.__is_element_valid(v2)
        p1 = self.find(v1)
        p2 = self.find(v2)
        # already in same set
        if p1 == p2:
            return
        # swap p1 and p2 if size of p1 more than p2, as we will make p2 parent of p1
        if self.__weights[p1] > self.__weights[p2]:
            p1, p2 = p2, p1
        # make p2 parent of p1 and update size of p2
        self.__parents[p1] = p2
        self.__weights[p2] += self.__weights[p1]

    def __is_element_valid(self, v):
        if v >= self.__size or v < 0:
            raise ValueError(f"Invalid set element: {v}")
class DisjointSetWeightedCompressedTest(unittest.TestCase):
    def test_1(self):
        d = DisjointSetWeightedCompressed(10)
        # Initially each element's parent is itself
        self.assertEqual(d.find(1), 1)
        self.assertEqual(d.find(2), 2)
        # -1 is not present in set and hence raises ValueError
        with self.assertRaises(ValueError):
            d.find(-1)
        # create set {1,2}
        d.union(1, 2)
        self.assertEqual(d.find(1), d.find(2))
        # create set {3,4}
        d.union(3, 4)
        self.assertEqual(d.find(3), d.find(4))
        # create set {1,2,3,4}
        d.union(1, 4)
        # after union of 1 and 4, 2 and 3 should also be in the same set
        self.assertEqual(d.find(2), d.find(3))

_ = unittest.TextTestRunner(verbosity=2).run(unittest.defaultTestLoader.loadTestsFromTestCase(DisjointSetWeightedCompressedTest))
test_1 (__main__.DisjointSetWeightedCompressedTest) ...
ok
----------------------------------------------------------------------
Ran 1 test in 0.002s
OK
The above implementation should suffice for most interview needs, and you will rarely need to go deeper. If you are still intrigued, feel free to explore other resources for explanations and further optimizations.
The Final Disjoint Set: Union by size + path compression + hashmaps
All the above implementations use integer arrays to store the nodes/elements, the heights, the weights, etc. However, in many practical use cases your data might not be just numbers; it can be strings, custom objects, etc. So we define a disjoint set that uses hashmaps (dict in Python) to map nodes to their parents and to store the size of each set.
In this implementation, we also provide an additional make_set() method to allow adding new values to the disjoint set.
Note: the elements must be hashable.
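The hashability requirement comes from using the elements as dict keys. As a quick illustration, the sketch below checks a few candidate element types; City and is_hashable are hypothetical names introduced only for this example.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class City:
    """Hypothetical custom element type; frozen dataclasses are hashable."""
    name: str
    country: str

def is_hashable(value):
    """Return True if value can be used as a dict key."""
    try:
        hash(value)
        return True
    except TypeError:
        return False

candidates = ["A", 42, (1, 2), City("Pune", "IN"), [1, 2], {"k": "v"}]
usable = [is_hashable(c) for c in candidates]
print(usable)  # [True, True, True, True, False, False]
```

Strings, numbers, tuples of hashable values and frozen dataclass instances can be stored as elements; lists and dicts cannot, and would need to be converted (for example to tuples) first.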
class DisjointSet:
    def __init__(self, elements):
        """
        Create a disjoint set object with each of the given elements in its own subset
        """
        self.__parents = {element: element for element in elements}  # dict<node, parent node>
        self.__weights = {element: 1 for element in elements}  # dict<node, weight>

    def make_set(self, value):
        """
        Add a new value as a disjoint subset if it does not already exist
        """
        if value in self.__parents:
            return
        self.__parents[value] = value
        self.__weights[value] = 1

    def find(self, v):
        """
        Find iteratively with path compression
        """
        self.__is_element_valid(v)
        root = v
        while self.__parents[root] != root:
            root = self.__parents[root]
        # path compression step
        while v != root:
            self.__parents[v], v = root, self.__parents[v]
        return root

    def union(self, v1, v2):
        """
        Attaches the root of the smaller set under the root of the larger set
        """
        self.__is_element_valid(v1)
        self.__is_element_valid(v2)
        p1 = self.find(v1)
        p2 = self.find(v2)
        # already in same set
        if p1 == p2:
            return
        # swap p1 and p2 if size of p1 more than p2, as we will make p2 parent of p1
        if self.__weights[p1] > self.__weights[p2]:
            p1, p2 = p2, p1
        # make p2 parent of p1 and update size of p2
        self.__parents[p1] = p2
        self.__weights[p2] += self.__weights[p1]

    def __is_element_valid(self, v):
        if v not in self.__parents:
            raise ValueError(f"Invalid set element: {v}")
class DisjointSetTest(unittest.TestCase):
    def test_1(self):
        d = DisjointSet(["A", "B", "C", "D", "E"])
        # Initially each element's parent is itself
        self.assertEqual(d.find("A"), "A")
        self.assertEqual(d.find("B"), "B")
        # "X" is not present in set and hence raises ValueError
        with self.assertRaises(ValueError):
            d.find("X")
        # create set {"A","B"}
        d.union("A", "B")
        self.assertEqual(d.find("A"), d.find("B"))
        # create set {"C", "D"}
        d.union("C", "D")
        self.assertEqual(d.find("C"), d.find("D"))
        # create set {"A", "B", "C", "D"}
        d.union("A", "D")
        # after union of "A" and "D", "B" and "C" should also be in the same set
        self.assertEqual(d.find("B"), d.find("C"))
        # a newly added element starts as the root of its own set
        d.make_set("P")
        self.assertEqual(d.find("P"), "P")
        d.union("E", "P")
        self.assertEqual(d.find("E"), d.find("P"))
        self.assertNotEqual(d.find("A"), d.find("P"))

_ = unittest.TextTestRunner(verbosity=2).run(unittest.defaultTestLoader.loadTestsFromTestCase(DisjointSetTest))
test_1 (__main__.DisjointSetTest) ...
ok
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
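As mentioned in the introduction, disjoint sets play a special role in Kruskal's minimum spanning tree algorithm. The following is a minimal sketch of that use, assuming a connected, undirected graph given as a list of (weight, u, v) tuples. To keep the block self-contained it inlines a small dict-based union-find with union by size and path compression instead of reusing the DisjointSet class above; the kruskal function and the sample graph are made up for this illustration.

```python
def kruskal(vertices, edges):
    """Return (total_weight, chosen_edges) of a minimum spanning tree.

    edges is a list of (weight, u, v) tuples; the graph is assumed connected.
    """
    parent = {v: v for v in vertices}
    size = {v: 1 for v in vertices}

    def find(v):
        root = v
        while parent[root] != root:
            root = parent[root]
        # path compression
        while v != root:
            parent[v], v = root, parent[v]
        return root

    total, chosen = 0, []
    # Consider edges from lightest to heaviest.
    for weight, u, v in sorted(edges):
        ru, rv = find(u), find(v)
        if ru == rv:
            continue  # u and v are already connected: the edge would form a cycle
        # union by size
        if size[ru] > size[rv]:
            ru, rv = rv, ru
        parent[ru] = rv
        size[rv] += size[ru]
        total += weight
        chosen.append((u, v, weight))
    return total, chosen

edges = [(1, "A", "B"), (4, "A", "C"), (2, "B", "C"), (5, "C", "D"), (3, "B", "D")]
total, chosen = kruskal(["A", "B", "C", "D"], edges)
print(total)   # 6
print(chosen)  # [('A', 'B', 1), ('B', 'C', 2), ('B', 'D', 3)]
```

The find calls answer "are these two vertices already connected?", and the union step records that an accepted edge has joined two components, which is exactly the pair of operations a disjoint set provides.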